import argparse
import os

import h5py
import numpy as np
import torch
from IPython import embed
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import cubify, sample_points_from_meshes
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so flags need an explicit parser.
    An empty string maps to False, matching ``bool('')``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(
    description='Compare predicted voxel shapes with the shape2prog test set '
                'using Chamfer distance and, optionally, EMD.')

parser.add_argument('--save_name', type = str, required = True,
                    help='run name; predictions are read from '
                         './shape2prog/vqprogram_outputs/test<save_name>/pred/<category>.h5 '
                         '(unused with --ori)')
parser.add_argument('--category', type = str, required = True,
                    help='shape category; selects the prediction file and '
                         './shape2prog/data/<category>_testing.h5 as the targets')
# nargs='?' + const=True: a bare `--emd` enables the flag, while the
# existing `--emd True` / `--emd False` spellings keep working.
parser.add_argument('--emd', type = _str2bool, nargs = '?', const = True, default = False,
                    help="also compute the earth mover's distance via optimal "
                         "point matching (slow)")
parser.add_argument('--ori', type = _str2bool, nargs = '?', const = True, default = False,
                    help='evaluate the original shape2prog outputs in '
                         './shape2prog/output/<category>/shapes.h5 instead')

args = parser.parse_args()

# Predicted voxel shapes: either the original shape2prog outputs (--ori,
# dataset 'data') or this run's predictions (dataset 'shape').
if args.ori:
    ours_path = os.path.join('./shape2prog/output/', args.category, 'shapes.h5')
    ours_key = 'data'
else:
    ours_path = os.path.join('./shape2prog/vqprogram_outputs/', 'test'+args.save_name, 'pred', args.category+'.h5')
    ours_key = 'shape'

# np.array copies each dataset into memory, so the HDF5 handles can be
# closed immediately instead of staying open for the whole run.
with h5py.File(ours_path, 'r') as ours_shape_h5:
    ours_shapes = np.array(ours_shape_h5[ours_key])

with h5py.File(os.path.join('./shape2prog/data/', args.category + '_testing.h5'), 'r') as target_shape_h5:
    target_shapes = np.array(target_shape_h5['data'])

def _voxels_to_point_clouds(voxel_grids, threshold=0.5):
    """Turn a stack of voxel grids into sampled surface point clouds.

    Each grid is meshed with ``cubify`` at ``threshold``, and its surface is
    sampled with ``sample_points_from_meshes`` (library-default sample count).
    The per-shape ``(1, P, 3)`` samples are stacked into one numpy array of
    shape ``(len(voxel_grids), P, 3)``.
    """
    clouds = []
    for grid in voxel_grids:
        # cubify expects a batch dimension, so add one per shape.
        mesh = cubify(torch.Tensor(grid).unsqueeze(0), threshold)
        clouds.append(sample_points_from_meshes(mesh))
    return np.vstack(clouds)


# Same conversion for predictions and targets so both are sampled identically.
ours_pc = _voxels_to_point_clouds(ours_shapes)
target_pc = _voxels_to_point_clouds(target_shapes)

torch.set_printoptions(precision=7)
# Batched Chamfer distance between ours_pc[i] and target_pc[i]; [0] selects
# the distance term from the tuple pytorch3d returns.
cd_dis = chamfer_distance(torch.Tensor(ours_pc).cuda(), torch.Tensor(target_pc).cuda())[0]
print('cd_dis:', cd_dis)

if args.emd:
    emd_dis = []
    for i, (q1, q2) in enumerate(zip(ours_pc, target_pc)):
        print('emd', i)
        # cost[a, b] = ||q1[a] - q2[b]||^2. cdist avoids materialising the
        # (P, P, 3) difference tensor, which is several GB for P = 10000.
        cost = cdist(q1, q2, 'sqeuclidean')
        # Optimal one-to-one matching (Hungarian algorithm).
        # NOTE(review): the matching minimises squared distance, but the
        # reported value is the mean Euclidean distance of matched pairs.
        # This is kept as originally written; confirm it is intended.
        row_ind, col_ind = linear_sum_assignment(cost)
        matched = q1[row_ind] - q2[col_ind]
        emd_dis.append(np.mean(np.linalg.norm(matched, axis=1)))
    # Print only when EMD was computed; emd_dis does not exist otherwise.
    print('emd_dis:', np.mean(np.array(emd_dis)))

# Open an interactive IPython shell with all results in scope (debugging aid).
embed()
